import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Subset, TensorDataset, Dataset
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from typing import List, Dict
from sklearn.neighbors import kneighbors_graph
from PIL import Image
import os

def gaussian_kernel(A: torch.Tensor, B: torch.Tensor, sigma: float = 1/(3*64*64)) -> torch.Tensor:
    """
    Computes the Gaussian (RBF) kernel between two sets of points.

    Entry (i, j) of the result is ``exp(-sigma * ||A[i] - B[j]||^2)``.

    Parameters:
    - A (torch.Tensor): A tensor of shape (n, d) representing n points in d-dimensional space.
    - B (torch.Tensor): A tensor of shape (m, d) representing m points in d-dimensional space.
    - sigma (float, optional): Factor multiplying the squared Euclidean distance inside
      the exponential. Note that this is NOT a standard deviation: larger values make
      the kernel narrower. Default is 1/(3*64*64).
      NOTE(review): the default presumably targets flattened 3x64x64 images -- confirm.

    Returns:
    - torch.Tensor: A tensor of shape (n, m) representing the Gaussian kernel values between points in A and B.
    """
    # Broadcast to (n, m, d) so that every pair of points is compared.
    diff = A.unsqueeze(1) - B.unsqueeze(0)
    # Sum the squares directly instead of squaring a norm: this skips the
    # sqrt round-trip (whose derivative is undefined at zero distance, e.g. on
    # the diagonal when A is B) and avoids the deprecated torch.norm.
    sq_dist = (diff * diff).sum(dim=2)
    return torch.exp(-sigma * sq_dist)


def rational_quadratic_kernel(A: torch.Tensor, B: torch.Tensor, alpha: float = 2.0, scale: float = 1/(3 * 64 * 64)) -> torch.Tensor:
    """
    Evaluates the Rational Quadratic (RQ) kernel for every pair of points.

    Entry (i, j) of the result is
    ``(1 + scale * ||A[i] - B[j]||^2 / (2 * alpha)) ** (-alpha)``.

    Parameters:
    - A (torch.Tensor): A tensor of shape (n, d) holding n points in d-dimensional space.
    - B (torch.Tensor): A tensor of shape (m, d) holding m points in d-dimensional space.
    - alpha (float, optional): The shape parameter of the RQ kernel. Default is 2.0.
    - scale (float, optional): Factor multiplying the squared distance (larger values
      make the kernel decay faster). Default is 1/(3*64*64).

    Returns:
    - torch.Tensor: A tensor of shape (n, m) with the RQ kernel value for each (A[i], B[j]) pair.
    """
    # (n, 1, d) - (1, m, d) broadcasts to all pairwise differences, (n, m, d).
    pairwise_delta = A[:, None] - B[None]
    sq_dist = pairwise_delta.pow(2).sum(dim=2)
    base = 1 + sq_dist * scale / (2 * alpha)
    return torch.pow(base, -alpha)


def kneighbors_graph_torch(X: torch.Tensor, neighbors: int) -> torch.Tensor:
    """
    Build a directed k-nearest-neighbor graph in PyTorch (GPU-capable).

    Parameters:
    - X (torch.Tensor): A tensor of shape (batch, d); each row is one point.
    - neighbors (int): Number of nearest neighbors to mark per point. Values larger
      than ``batch - 1`` are clamped, since a point has at most ``batch - 1``
      neighbors other than itself.

    Returns:
    - torch.Tensor: Adjacency matrix of shape (batch, batch), created on X's device
      with the default floating dtype. Entry (i, j) is 1.0 when point j is among the
      nearest neighbors of point i and 0.0 otherwise; the diagonal is always 0.

    Raises:
    - ValueError: If ``neighbors`` is negative.
    """
    if neighbors < 0:
        raise ValueError(f"neighbors must be non-negative, got {neighbors}")

    batch = X.shape[0]
    N_X = torch.zeros((batch, batch), device=X.device)

    k = min(neighbors, batch - 1)
    if k <= 0:
        # No other point is available (or none was requested): empty graph.
        return N_X

    # Pairwise (non-squared) Euclidean distances, [batch, batch]. Only the
    # ordering is used, so the input is detached to keep autograd out of it.
    points = X.detach()
    dist = torch.cdist(points, points, p=2)

    # Exclude self-matches explicitly. Dropping the first sorted column is not
    # reliable: with duplicate points the zero self-distance ties with others
    # and the point itself need not be sorted first.
    dist.fill_diagonal_(float("inf"))

    # Indices of the k smallest distances per row; cheaper than a full argsort.
    knn_idx = torch.topk(dist, k, dim=1, largest=False).indices

    # Mark neighbors with 1.
    N_X.scatter_(1, knn_idx, 1.0)

    return N_X



def ECMMD(Z: torch.Tensor, Y: torch.Tensor, X: torch.Tensor, kernel, neighbors: int) -> torch.Tensor:
    """
    Computes a nearest-neighbor-weighted kernel discrepancy between Z and Y.

    A k-nearest-neighbor graph is built on the rows of X. For every directed
    edge (i, j) of that graph the quantity
    ``k(Z_i, Z_j) + k(Y_i, Y_j) - k(Z_i, Y_j) - k(Y_i, Z_j)`` is accumulated,
    and the total is divided by ``batch * neighbors``.

    Parameters:
    - Z (torch.Tensor): First sample; row i is paired with row i of X.
    - Y (torch.Tensor): Second sample; row i is paired with row i of X.
    - X (torch.Tensor): Points of shape (batch, d) used to build the neighbor graph.
    - kernel (callable): Called as ``kernel(P, Q)``; must return a (batch, batch)
      tensor built from PyTorch ops so the result stays differentiable.
    - neighbors (int): Number of nearest neighbors per point in the graph on X.

    Returns:
    - torch.Tensor: A scalar tensor holding the averaged statistic.
    """
    batch = X.shape[0]

    # Adjacency matrix [batch, batch]: 1 where j is a nearest neighbor of i in X.
    N_X = kneighbors_graph_torch(X, neighbors)

    # Compute kernel matrices (must be PyTorch ops!)
    kernel_ZZ = kernel(Z, Z)
    kernel_YY = kernel(Y, Y)
    kernel_ZY = kernel(Z, Y)
    # Evaluated separately rather than transposing kernel_ZY, so no symmetry
    # of `kernel` is assumed.
    kernel_YZ = kernel(Y, Z)

    # Compute H matrix
    H = kernel_ZZ + kernel_YY - kernel_ZY - kernel_YZ

    # Keep only the entries on graph edges and average over the nominal edge count.
    return torch.sum(H * N_X) / (batch * neighbors)


def prepare_data(dataset, noise_factor):
    """
    Builds paired (noisy, clean) image batches from an iterable of images.

    Entries that are ``None`` are skipped. Every remaining image has its
    leading dimension squeezed (when that dimension has size 1), then receives
    additive Gaussian noise scaled by ``noise_factor``; the noisy copy is
    clipped to the range [0, 1]. The clean image is kept as-is (not clipped).

    Parameters:
    - dataset: Iterable yielding image tensors (or None for entries to skip).
    - noise_factor: Multiplier applied to the standard-normal noise.

    Returns:
    - tuple: ``(noisy_images, true_images)``, each stacked along a new first
      dimension in the order the images were encountered.
    """
    clean_batch = []
    corrupted_batch = []
    for sample in dataset:
        if sample is None:
            continue
        clean = sample.squeeze(0)
        perturbation = torch.randn_like(clean) * noise_factor
        corrupted = torch.clamp(clean + perturbation, min=0., max=1.)

        clean_batch.append(clean)
        corrupted_batch.append(corrupted)
    return torch.stack(corrupted_batch), torch.stack(clean_batch)

def display_images(noisy_images: torch.Tensor, true_images: torch.Tensor, num_images: int = 5) -> None:
    """
    Shows noisy images (top row) above their clean counterparts (bottom row).

    Parameters:
    - noisy_images (torch.Tensor): Batch of noisy images indexed along dim 0.
      Each image is permuted with (1, 2, 0) before plotting, so it is presumably
      laid out channels-first, i.e. (C, H, W) -- confirm against the caller.
    - true_images (torch.Tensor): Batch of clean images with the same layout.
    - num_images (int, optional): Maximum number of image pairs to show. Fewer are
      shown when either batch holds fewer images. Default is 5.

    Raises:
    - ValueError: If there is no image pair to display.
    """
    # Never index past the end of either batch.
    count = min(num_images, noisy_images.shape[0], true_images.shape[0])
    if count < 1:
        raise ValueError("display_images needs at least one image pair to show")

    # squeeze=False keeps `axes` two-dimensional even for a single column, so
    # axes[row, col] indexing also works when count == 1.
    _, axes = plt.subplots(2, count, figsize=(5, 3), squeeze=False)

    # Move only the displayed images to the CPU for plotting with matplotlib;
    # detach so tensors that require grad can be converted as well.
    noisy_images_cpu = noisy_images[:count].detach().cpu()
    true_images_cpu = true_images[:count].detach().cpu()

    for col in range(count):
        for row, images in enumerate((noisy_images_cpu, true_images_cpu)):
            axes[row, col].imshow(images[col].permute(1, 2, 0))
            axes[row, col].axis('off')
    plt.show()